#pragma once

#include <cassert>
#include <utility>

namespace simulation {
	enum class Color {
		RED,
		BLACK
	};

	template<class T>
	struct RBTreeNode {
		RBTreeNode<T>* _left;
		RBTreeNode<T>* _right;
		RBTreeNode<T>* _parent;

		T _data;

		Color _col;

		RBTreeNode(const T& data)
			: _left(nullptr)
			, _right(nullptr)
			, _parent(nullptr)
			, _data(data)
			, _col(Color::RED)
		{}
	};

	//T, T&, T*
	template<class T, class Ref, class Ptr>
	class RBTreeIterator {
	private:
		typedef RBTreeNode<T> Node;
		Node* _node;

		typedef RBTreeIterator<T, Ref, Ptr> Self;

	public:
		RBTreeIterator(Node* node)
			: _node(node)
		{}

		Ref operator* ()
		{
			return _node->_data;
		}

		Ptr operator-> ()
		{
			return &(_node->_data);
		}

		bool operator== (const Self& self) const
		{
			return _node == self._node;
		}

		bool operator!= (const Self& self) const
		{
			return _node != self._node;
		}

		Self& operator++ ()
		{
			if (_node->_right)
			{
				Node* subRL = _node->_right;
				while (subRL->_left)
				{
					subRL = subRL->_left;
				}

				_node = subRL;
			}
			else
			{
				Node* cur = _node;
				Node* parent = cur->_parent;
				while (parent && cur == parent->_right)
				{
					cur = parent;
					parent = cur->_parent;
				}

				_node = parent;
			}

			return *this;
		}

		Self& operator-- ()
		{
			if (_node->_left)
			{
				Node* subLR = _node->_left;
				while (subLR->_right)
				{
					subLR = subLR->_right;
				}

				_node = subLR;
			}
			else
			{
				Node* cur = _node;
				Node* parent = _node->_parent;

				while (parent && cur == parent->_left)
				{
					cur = parent;
					parent = cur->_parent;
				}

				_node = parent;
			}

			return *this;
		}
	};

	template<class K, class T, class KOfD>
	class RBTree {
	private:
		typedef RBTreeNode<T> Node;
		Node* _root;

	public:
		typedef RBTreeIterator<T, T&, T*> iterator;
		typedef RBTreeIterator<T, const T&, const T*> const_iterator;

	public:
		RBTree()
			: _root(nullptr)
		{}

		std::pair<iterator, bool> insert(const T& data)
		{
			//锟斤拷锟斤拷
			if (_root == nullptr)
			{
				_root = new Node(data);
				_root->_col = Color::BLACK; //锟斤拷锟斤拷锟斤拷母锟轿拷锟缴�
				return std::make_pair(iterator(_root), true);
			}

			KOfD getKeyOf;
			Node* cur = _root;
			Node* parent = nullptr;

			while (cur)
			{
				if (getKeyOf(data) > getKeyOf(cur->_data))
				{
					parent = cur;
					cur = cur->_right;
				}
				else if (getKeyOf(data) < getKeyOf(cur->_data))
				{
					parent = cur;
					cur = cur->_left;
				}
				else
				{
					return std::make_pair(iterator(cur), false);
				}
			}

			cur = new Node(data);
			getKeyOf(data) < getKeyOf(parent->_data)
				? parent->_left = cur
				: parent->_right = cur;
			cur->_parent = parent;

			Node* newNode = cur; //锟斤拷录锟铰诧拷锟斤拷诘锟斤拷位锟斤拷

			//锟斤拷锟斤拷锟斤拷色
			//锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷grandparent一锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷色为锟节★拷锟斤拷为锟斤拷锟斤拷锟斤拷锟斤拷冢锟絧arent锟斤拷为锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷母锟斤拷锟斤拷锟斤拷锟轿拷锟缴拷锟�
			//锟斤拷锟絞randparent为锟斤拷色锟酵诧拷锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷锟侥猴拷色锟节点）
			while (parent && parent->_col == Color::RED)
			{
				Node* grandparent = parent->_parent;

				if (parent == grandparent->_left)
				{
					Node* uncle = grandparent->_right;

					if (uncle && uncle->_col == Color::RED)
					{
						//锟斤拷锟揭伙拷锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷为锟斤拷
						parent->_col = uncle->_col = Color::BLACK;
						grandparent->_col = Color::RED;
					}
					else
					{
						//锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷宀伙拷锟斤拷诨锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷为锟斤拷
						//锟斤拷锟斤拷锟斤拷宀伙拷锟斤拷冢锟斤拷锟絚ur一锟斤拷锟斤拷锟斤拷为锟斤拷锟斤拷锟斤拷锟斤拷锟节点导锟铰的猴拷色
						//锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷为锟节ｏ拷锟斤拷么cur一锟斤拷锟斤拷锟斤拷锟斤拷为锟斤拷锟斤拷锟斤拷锟节碉拷锟侥拷锟轿拷斓硷拷碌暮锟缴拷锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷一锟戒化锟斤拷锟斤拷锟斤拷
						if (cur == parent->_right) //双锟斤拷
						{
							//双锟斤拷锟斤拷锟皆匡拷锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷锟饺达拷锟斤拷一锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷锟诫单锟斤拷锟斤拷锟斤拷锟揭伙拷锟斤拷锟�
							rotateL(parent);

							//锟斤拷锟斤拷转锟斤拷一锟轿猴拷parent锟斤拷cur锟斤拷指锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷要锟斤拷指锟斤拷一锟斤拷锟剿ｏ拷锟斤拷锟斤拷锟斤拷锟斤拷锟揭伙拷锟�
							std::swap(cur, parent);
						}

						rotateR(grandparent);
						cur->_col = Color::BLACK;
						grandparent->_col = Color::RED;

						break; //锟斤拷转锟疥，锟斤拷锟斤拷锟斤拷锟斤拷锟斤拷色锟斤拷锟斤拷确锟斤拷
					}
				}
				else //parent == grandparent->right
				{
					Node* uncle = grandparent->_left;

					if (uncle && uncle->_col == Color::RED)
					{
						parent->_col = uncle->_col = Color::BLACK;
						grandparent->_col = Color::RED;
					}
					else
					{
						if (cur == parent->_left)
						{
							rotateR(parent);
							std::swap(cur, parent);
						}

						rotateL(grandparent);
						cur->_col = Color::BLACK;
						grandparent->_col = Color::RED;

						break;
					}
				}

				//锟斤拷锟斤拷锟斤拷锟较革拷锟斤拷
				cur = grandparent;
				parent = cur->_parent;
			}

			_root->_col = Color::BLACK;
			return std::make_pair(iterator(newNode), true);
		}

		iterator begin()
		{
			Node* cur = _root;
			while (cur && cur->_left)
			{
				cur = cur->_left;
			}

			return iterator(cur);
		}

		const_iterator begin() const
		{
			Node* cur = _root;
			while (cur && cur->_left)
			{
				cur = cur->_left;
			}

			return const_iterator(cur);
		}

		iterator end()
		{
			return iterator(nullptr);
		}

		const_iterator end() const
		{
			return const_iterator(nullptr);
		}

		iterator find(const K& key)
		{
			KOfD getKeyOf;
			Node* cur = _root;

			while (cur)
			{
				if (key > getKeyOf(cur->_data))
				{
					cur = cur->_right;
				}
				else if (key < getKeyOf(cur->_data))
				{
					cur = cur->_left;
				}
				else
				{
					return iterator(cur);
				}
			}

			return end();
		}

		const_iterator find(const K& key) const
		{
			KOfD getKeyOf;
			Node* cur = _root;

			while (cur)
			{
				if (key > getKeyOf(cur->_data))
				{
					cur = cur->_right;
				}
				else if (key < getKeyOf(cur->_data))
				{
					cur = cur->_left;
				}
				else
				{
					return iterator(cur);
				}
			}

			return end();
		}

	private:
		void rotateL(Node* parent)
		{
			assert(parent);

			Node* subR = parent->_right;
			Node* subRL = subR->_left;
			Node* grandparent = parent->_parent;

			parent->_right = subRL;
			if (subRL)
			{
				subRL->_parent = parent;
			}

			subR->_left = parent;
			parent->_parent = subR;

			if (_root == parent)
			{
				_root = subR;
			}
			else
			{
				grandparent->_left == parent
					? grandparent->_left = subR
					: grandparent->_right = subR;
			}

			subR->_parent = grandparent;
		}

		void rotateR(Node* parent)
		{
			assert(parent);
			Node* subL = parent->_left;
			Node* subLR = subL->_right;
			Node* grandparent = parent->_parent;

			parent->_left = subLR;
			if (subLR)
			{
				subLR->_parent = parent;
			}

			subL->_right = parent;
			parent->_parent = subL;

			if (parent == _root)
			{
				_root = subL;
			}
			else
			{
				grandparent->_left == parent
					? grandparent->_left = subL
					: grandparent->_right = subL;
			}

			subL->_parent = grandparent;
		}
	};
}